{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s_qNSzzyaCbD"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "jmjh290raIky"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J0Qjg6vuaHNt"
      },
      "source": [
        "# Neural machine translation with attention"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AOpGoE2T-YXS"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/text/tutorials/nmt_with_attention\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003e\n",
        "    View on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/nmt_with_attention.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/tutorials/nmt_with_attention.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/nmt_with_attention.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xh8WNEwYA3BW"
      },
      "source": [
        "This tutorial demonstrates how to train a sequence-to-sequence (seq2seq) model for Spanish-to-English translation roughly based on [Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025v5) (Luong et al., 2015). \n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=400 src=\"https://www.tensorflow.org/images/tutorials/transformer/RNN%2Battention-words-spa.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThis tutorial: An encoder/decoder connected by attention.\u003c/th\u003e\n",
        "\u003ctr\u003e\n",
        "\u003c/table\u003e\n",
        "\n",
        "While this architecture is somewhat outdated, it is still a very useful project to work through to get a deeper understanding of sequence-to-sequence models and attention mechanisms (before going on to [Transformers](transformer.ipynb))."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CiwtNgENbx2g"
      },
      "source": [
        "\n",
        "\n",
        "This example assumes some knowledge of TensorFlow fundamentals below the level of a Keras layer:\n",
        "  * [Working with tensors](https://www.tensorflow.org/guide/tensor) directly\n",
        "  * [Writing custom `keras.Model`s and `keras.layers`](https://www.tensorflow.org/guide/keras/custom_layers_and_models)\n",
        "\n",
        "After training the model in this notebook, you will be able to input a Spanish sentence, such as \"*¿todavia estan en casa?*\", and return the English translation: \"*are you still at home?*\"\n",
        "\n",
        "The resulting model is exportable as a `tf.saved_model`, so it can be used in other TensorFlow environments.\n",
        "\n",
        "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n",
        "\n",
        "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n",
        "\n",
        "Note: This example takes approximately 10 minutes to run."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yAmSR1FaqKrl"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DGFTkuRvzWqc"
      },
      "outputs": [],
      "source": [
        "!pip install \"tensorflow-text\u003e=2.11\"\n",
        "!pip install einops"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tnxXKDjq3jEL"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "import typing\n",
        "from typing import Any, Tuple\n",
        "\n",
        "import einops\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.ticker as ticker\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_text as tf_text"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l_yq8kvIqoqQ"
      },
      "source": [
        "This tutorial uses a lot of low level API's where it's easy to get shapes wrong. This class is used to check shapes throughout the tutorial.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KqFqKi4fqN9X"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "class ShapeChecker():\n",
        "  def __init__(self):\n",
        "    # Keep a cache of every axis-name seen\n",
        "    self.shapes = {}\n",
        "\n",
        "  def __call__(self, tensor, names, broadcast=False):\n",
        "    if not tf.executing_eagerly():\n",
        "      return\n",
        "\n",
        "    parsed = einops.parse_shape(tensor, names)\n",
        "\n",
        "    for name, new_dim in parsed.items():\n",
        "      old_dim = self.shapes.get(name, None)\n",
        "      \n",
        "      if (broadcast and new_dim == 1):\n",
        "        continue\n",
        "\n",
        "      if old_dim is None:\n",
        "        # If the axis name is new, add its length to the cache.\n",
        "        self.shapes[name] = new_dim\n",
        "        continue\n",
        "\n",
        "      if new_dim != old_dim:\n",
        "        raise ValueError(f\"Shape mismatch for dimension: '{name}'\\n\"\n",
        "                         f\"    found: {new_dim}\\n\"\n",
        "                         f\"    expected: {old_dim}\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gjUROhJfH3ML"
      },
      "source": [
        "## The data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "puE_K74DIE9W"
      },
      "source": [
        "The tutorial uses a language dataset provided by [Anki](http://www.manythings.org/anki/). This dataset contains language translation pairs in the format:\n",
        "\n",
        "```\n",
        "May I borrow this book?\t¿Puedo tomar prestado este libro?\n",
        "```\n",
        "\n",
        "They have a variety of languages available, but this example uses the English-Spanish dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wfodePkj3jEa"
      },
      "source": [
        "### Download and prepare the dataset\n",
        "\n",
        "For convenience, a copy of this dataset is hosted on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps you need to take to prepare the data:\n",
        "\n",
        "1. Add a *start* and *end* token to each sentence.\n",
        "2. Clean the sentences by removing special characters.\n",
        "3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n",
        "4. Pad each sentence to a maximum length."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kRVATYOgJs1b"
      },
      "outputs": [],
      "source": [
        "# Download the file\n",
        "import pathlib\n",
        "\n",
        "path_to_zip = tf.keras.utils.get_file(\n",
        "    'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',\n",
        "    extract=True)\n",
        "\n",
        "path_to_file = pathlib.Path(path_to_zip).parent/'spa-eng/spa.txt'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OHn4Dct23jEm"
      },
      "outputs": [],
      "source": [
        "def load_data(path):\n",
        "  text = path.read_text(encoding='utf-8')\n",
        "\n",
        "  lines = text.splitlines()\n",
        "  pairs = [line.split('\\t') for line in lines]\n",
        "\n",
        "  context = np.array([context for target, context in pairs])\n",
        "  target = np.array([target for target, context in pairs])\n",
        "\n",
        "  return target, context"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cTbSbBz55QtF"
      },
      "outputs": [],
      "source": [
        "target_raw, context_raw = load_data(path_to_file)\n",
        "print(context_raw[-1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lH_dPY8TRp3c"
      },
      "outputs": [],
      "source": [
        "print(target_raw[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rgCLkfv5uO3d"
      },
      "source": [
        "### Create a tf.data dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PfVWx3WaI5Df"
      },
      "source": [
        "From these arrays of strings you can create a `tf.data.Dataset` of strings that shuffles and batches them efficiently:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3rZFgz69nMPa"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = len(context_raw)\n",
        "BATCH_SIZE = 64\n",
        "\n",
        "is_train = np.random.uniform(size=(len(target_raw),)) \u003c 0.8\n",
        "\n",
        "train_raw = (\n",
        "    tf.data.Dataset\n",
        "    .from_tensor_slices((context_raw[is_train], target_raw[is_train]))\n",
        "    .shuffle(BUFFER_SIZE)\n",
        "    .batch(BATCH_SIZE))\n",
        "val_raw = (\n",
        "    tf.data.Dataset\n",
        "    .from_tensor_slices((context_raw[~is_train], target_raw[~is_train]))\n",
        "    .shuffle(BUFFER_SIZE)\n",
        "    .batch(BATCH_SIZE))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qc6-NK1GtWQt"
      },
      "outputs": [],
      "source": [
        "for example_context_strings, example_target_strings in train_raw.take(1):\n",
        "  print(example_context_strings[:5])\n",
        "  print()\n",
        "  print(example_target_strings[:5])\n",
        "  break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zCoxLcuN3bwv"
      },
      "source": [
        "### Text preprocessing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7kwdPcHvzz_a"
      },
      "source": [
        "One of the goals of this tutorial is to build a model that can be exported as a `tf.saved_model`. To make that exported model useful it should take `tf.string` inputs, and return `tf.string` outputs: All the text processing happens inside the model. Mainly using a `layers.TextVectorization` layer."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EOQ5n55X4uDB"
      },
      "source": [
        "#### Standardization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "upKhKAMK4zzI"
      },
      "source": [
        "The model is dealing with multilingual text with a limited vocabulary. So it will be important to standardize the input text.\n",
        "\n",
        "The first step is Unicode normalization to split accented characters and replace compatibility characters with their ASCII equivalents.\n",
        "\n",
        "The `tensorflow_text` package contains a unicode normalize operation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mD0e-DWGQ2Vo"
      },
      "outputs": [],
      "source": [
        "example_text = tf.constant('¿Todavía está en casa?')\n",
        "\n",
        "print(example_text.numpy())\n",
        "print(tf_text.normalize_utf8(example_text, 'NFKD').numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6hTllEjK6RSo"
      },
      "source": [
        "Unicode normalization will be the first step in the text standardization function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "chTF5N885F0P"
      },
      "outputs": [],
      "source": [
        "def tf_lower_and_split_punct(text):\n",
        "  # Split accented characters.\n",
        "  text = tf_text.normalize_utf8(text, 'NFKD')\n",
        "  text = tf.strings.lower(text)\n",
        "  # Keep space, a to z, and select punctuation.\n",
        "  text = tf.strings.regex_replace(text, '[^ a-z.?!,¿]', '')\n",
        "  # Add spaces around punctuation.\n",
        "  text = tf.strings.regex_replace(text, '[.?!,¿]', r' \\0 ')\n",
        "  # Strip whitespace.\n",
        "  text = tf.strings.strip(text)\n",
        "\n",
        "  text = tf.strings.join(['[START]', text, '[END]'], separator=' ')\n",
        "  return text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UREvDg3sEKYa"
      },
      "outputs": [],
      "source": [
        "print(example_text.numpy().decode())\n",
        "print(tf_lower_and_split_punct(example_text).numpy().decode())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4q-sKsSI7xRZ"
      },
      "source": [
        "#### Text Vectorization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6aKn8qd37abi"
      },
      "source": [
        "This standardization function will be wrapped up in a `tf.keras.layers.TextVectorization` layer which will handle the vocabulary extraction and conversion of input text to sequences of tokens."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eAY9k49G3jE_"
      },
      "outputs": [],
      "source": [
        "max_vocab_size = 5000\n",
        "\n",
        "context_text_processor = tf.keras.layers.TextVectorization(\n",
        "    standardize=tf_lower_and_split_punct,\n",
        "    max_tokens=max_vocab_size,\n",
        "    ragged=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7kbC6ODP8IK_"
      },
      "source": [
        "The `TextVectorization` layer and many other [Keras preprocessing layers](https://www.tensorflow.org/guide/keras/preprocessing_layers) have an `adapt` method. This method reads one epoch of the training data, and works a lot like `Model.fit`. This `adapt` method initializes the layer based on the data. Here it determines the vocabulary:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bmsI1Yql8FYe"
      },
      "outputs": [],
      "source": [
        "context_text_processor.adapt(train_raw.map(lambda context, target: context))\n",
        "\n",
        "# Here are the first 10 words from the vocabulary:\n",
        "context_text_processor.get_vocabulary()[:10]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9kGjIFjX8_Wp"
      },
      "source": [
        "That's the Spanish `TextVectorization` layer, now build and `.adapt()` the English one:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jlC4xuZnKLBS"
      },
      "outputs": [],
      "source": [
        "target_text_processor = tf.keras.layers.TextVectorization(\n",
        "    standardize=tf_lower_and_split_punct,\n",
        "    max_tokens=max_vocab_size,\n",
        "    ragged=True)\n",
        "\n",
        "target_text_processor.adapt(train_raw.map(lambda context, target: target))\n",
        "target_text_processor.get_vocabulary()[:10]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BWQqlP_s9eIv"
      },
      "source": [
        "Now these layers can convert a batch of strings into a batch of token IDs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9KZxj8IrNZ9S"
      },
      "outputs": [],
      "source": [
        "example_tokens = context_text_processor(example_context_strings)\n",
        "example_tokens[:3, :]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AA9rUn9G9n78"
      },
      "source": [
        "The `get_vocabulary` method can be used to convert token IDs back to text:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "98g9rcxGQY0I"
      },
      "outputs": [],
      "source": [
        "context_vocab = np.array(context_text_processor.get_vocabulary())\n",
        "tokens = context_vocab[example_tokens[0].numpy()]\n",
        "' '.join(tokens)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ot0aCL9t-Ghi"
      },
      "source": [
        "The returned token IDs are zero-padded. This can easily be turned into a mask:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_jx4Or_eFRSz"
      },
      "outputs": [],
      "source": [
        "plt.subplot(1, 2, 1)\n",
        "plt.pcolormesh(example_tokens.to_tensor())\n",
        "plt.title('Token IDs')\n",
        "\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.pcolormesh(example_tokens.to_tensor() != 0)\n",
        "plt.title('Mask')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3O0B4XdFlRgc"
      },
      "source": [
        "### Process the dataset\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rVCuyuSp_whd"
      },
      "source": [
        "The `process_text` function below converts the `Datasets` of strings, into  0-padded tensors of token IDs. It also converts from a `(context, target)` pair to an `((context, target_in), target_out)` pair for training with `keras.Model.fit`. Keras expects `(inputs, labels)` pairs, the inputs are the `(context, target_in)` and the labels are `target_out`. The difference between `target_in` and `target_out` is that they are shifted by one step relative to eachother, so that at each location the label is the next token."
      ]
    },
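    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of that one-step shift, the snippet below uses a hand-written token list (a hypothetical example, not the output of `target_text_processor`) and plain Python slicing in place of the ragged-tensor slicing used in `process_text`:\n",
        "\n",
        "```python\n",
        "# Hypothetical tokenized sentence; real tokens would be integer IDs.\n",
        "tokens = ['[START]', 'are', 'you', 'still', 'at', 'home', '?', '[END]']\n",
        "\n",
        "target_in = tokens[:-1]   # drop the last token\n",
        "target_out = tokens[1:]   # drop the first token\n",
        "\n",
        "# At each position, the label is the token that follows the input token.\n",
        "for inp, label in zip(target_in, target_out):\n",
        "  print(f'{inp} predicts {label}')\n",
        "```\n",
        "\n",
        "Both sequences have the same length, so the model's output at each step can be compared directly against the corresponding label."
      ]
    },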
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wk5tbZWQl5u1"
      },
      "outputs": [],
      "source": [
        "def process_text(context, target):\n",
        "  context = context_text_processor(context).to_tensor()\n",
        "  target = target_text_processor(target)\n",
        "  targ_in = target[:,:-1].to_tensor()\n",
        "  targ_out = target[:,1:].to_tensor()\n",
        "  return (context, targ_in), targ_out\n",
        "\n",
        "\n",
        "train_ds = train_raw.map(process_text, tf.data.AUTOTUNE)\n",
        "val_ds = val_raw.map(process_text, tf.data.AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4iGi7X2m_tbM"
      },
      "source": [
        "Here is the first sequence of each, from the first batch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "woQBWAjLsJkr"
      },
      "outputs": [],
      "source": [
        "for (ex_context_tok, ex_tar_in), ex_tar_out in train_ds.take(1):\n",
        "  print(ex_context_tok[0, :10].numpy()) \n",
        "  print()\n",
        "  print(ex_tar_in[0, :10].numpy()) \n",
        "  print(ex_tar_out[0, :10].numpy()) "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TNfHIF71ulLu"
      },
      "source": [
        "## The encoder/decoder\n",
        "\n",
        "The following diagrams shows an overview of the model. In both the encoder is on the left, the decoder is on the right. At each time-step the decoder's output is combined with the encoder's output, to predict the next word. \n",
        "\n",
        "The original [left] contains a few extra connections that are intentionally omitted from this tutorial's model [right], as they are generally unnecessary, and difficult to implement. Those missing connections are:\n",
        "\n",
        "1. Feeding the state from the encoder's RNN to the decoder's RNN\n",
        "2. Feeding the attention output back to the RNN's input.\n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=380 src=\"https://www.tensorflow.org/images/tutorials/transformer/RNN+attention.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe original from \u003ca href=https://arxiv.org/abs/1508.04025v5\u003eEffective Approaches to Attention-based Neural Machine Translation\u003c/a\u003e\u003c/th\u003e\n",
        "  \u003cth colspan=1\u003eThis tutorial's model\u003c/th\u003e\n",
        "\u003ctr\u003e\n",
        "\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gzQWx2saImMV"
      },
      "source": [
        "Before getting into it define constants for the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_a9uNz3-IrF-"
      },
      "outputs": [],
      "source": [
        "UNITS = 256"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "blNgVbLSzpsr"
      },
      "source": [
        "### The encoder\n",
        "\n",
        "The goal of the encoder is to process the context sequence into a sequence of vectors that are useful for the decoder as it attempts to predict the next output for each timestep. Since the context sequence is constant, there is no restriction on how information can flow in the encoder, so use a bidirectional-RNN to do the processing:\n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://tensorflow.org/images/tutorials/transformer/RNN-bidirectional.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth\u003eA bidirectional RNN\u003c/th\u003e\n",
        "\u003ctr\u003e\n",
        "\u003c/table\u003e\n",
        "\n",
        "The encoder:\n",
        "\n",
        "1. Takes a list of token IDs (from `context_text_processor`).\n",
        "3. Looks up an embedding vector for each token (Using a `layers.Embedding`).\n",
        "4. Processes the embeddings into a new sequence (Using a bidirectional `layers.GRU`).\n",
        "5. Returns the processed sequence. This will be passed to the attention head."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nZ2rI24i3jFg"
      },
      "outputs": [],
      "source": [
        "class Encoder(tf.keras.layers.Layer):\n",
        "  def __init__(self, text_processor, units):\n",
        "    super(Encoder, self).__init__()\n",
        "    self.text_processor = text_processor\n",
        "    self.vocab_size = text_processor.vocabulary_size()\n",
        "    self.units = units\n",
        "    \n",
        "    # The embedding layer converts tokens to vectors\n",
        "    self.embedding = tf.keras.layers.Embedding(self.vocab_size, units,\n",
        "                                               mask_zero=True)\n",
        "\n",
        "    # The RNN layer processes those vectors sequentially.\n",
        "    self.rnn = tf.keras.layers.Bidirectional(\n",
        "        merge_mode='sum',\n",
        "        layer=tf.keras.layers.GRU(units,\n",
        "                            # Return the sequence and state\n",
        "                            return_sequences=True,\n",
        "                            recurrent_initializer='glorot_uniform'))\n",
        "\n",
        "  def call(self, x):\n",
        "    shape_checker = ShapeChecker()\n",
        "    shape_checker(x, 'batch s')\n",
        "\n",
        "    # 2. The embedding layer looks up the embedding vector for each token.\n",
        "    x = self.embedding(x)\n",
        "    shape_checker(x, 'batch s units')\n",
        "\n",
        "    # 3. The GRU processes the sequence of embeddings.\n",
        "    x = self.rnn(x)\n",
        "    shape_checker(x, 'batch s units')\n",
        "\n",
        "    # 4. Returns the new sequence of embeddings.\n",
        "    return x\n",
        "\n",
        "  def convert_input(self, texts):\n",
        "    texts = tf.convert_to_tensor(texts)\n",
        "    if len(texts.shape) == 0:\n",
        "      texts = tf.convert_to_tensor(texts)[tf.newaxis]\n",
        "    context = self.text_processor(texts).to_tensor()\n",
        "    context = self(context)\n",
        "    return context"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D3SKkaQeGn-Q"
      },
      "source": [
        "Try it out:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "60gSVh05Jl6l"
      },
      "outputs": [],
      "source": [
        "# Encode the input sequence.\n",
        "encoder = Encoder(context_text_processor, UNITS)\n",
        "ex_context = encoder(ex_context_tok)\n",
        "\n",
        "print(f'Context tokens, shape (batch, s): {ex_context_tok.shape}')\n",
        "print(f'Encoder output, shape (batch, s, units): {ex_context.shape}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "45xM_Gl1MgXY"
      },
      "source": [
        "### The attention layer\n",
        "\n",
        "The attention layer lets the decoder access the information extracted by the encoder. It computes a vector from the entire context sequence, and adds that to the decoder's output. \n",
        "\n",
        "The simplest way you could calculate a single vector from the entire sequence would be to take the average across the sequence (`layers.GlobalAveragePooling1D`). An attention layer is similar, but calculates a **weighted** average across the context sequence. Where the weights are calculated from the combination of context and \"query\" vectors.\n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://www.tensorflow.org/images/tutorials/transformer/CrossAttention-new-full.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe attention layer\u003c/th\u003e\n",
        "\u003ctr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
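    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The difference between a plain average and a weighted average can be sketched in NumPy. This is only an illustration under simplifying assumptions: the sizes and random arrays are made up, and the weights come from a scaled dot product followed by a softmax with no learned parameters, whereas the `MultiHeadAttention` layer used in this tutorial also applies learned projections to the query, key, and value.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "s, units = 5, 4                          # hypothetical context length and feature size\n",
        "context = rng.normal(size=(s, units))    # stand-in for the encoder output\n",
        "query = rng.normal(size=(units,))        # stand-in for one decoder position\n",
        "\n",
        "# Plain average: every context position gets the same weight.\n",
        "uniform_average = context.mean(axis=0)\n",
        "\n",
        "# Attention: the weights come from comparing the query with each context vector.\n",
        "scores = context @ query / np.sqrt(units)\n",
        "weights = np.exp(scores - scores.max())\n",
        "weights /= weights.sum()                 # softmax, so the weights sum to 1\n",
        "weighted_average = weights @ context\n",
        "\n",
        "print(weights.sum(), weighted_average.shape)\n",
        "```\n",
        "\n",
        "Both results are single vectors of size `units`. The only difference is how much each context position contributes."
      ]
    },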
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-Ql3ymqwD8LS"
      },
      "outputs": [],
      "source": [
        "class CrossAttention(tf.keras.layers.Layer):\n",
        "  def __init__(self, units, **kwargs):\n",
        "    super().__init__()\n",
        "    self.mha = tf.keras.layers.MultiHeadAttention(key_dim=units, num_heads=1, **kwargs)\n",
        "    self.layernorm = tf.keras.layers.LayerNormalization()\n",
        "    self.add = tf.keras.layers.Add()\n",
        "\n",
        "  def call(self, x, context):\n",
        "    shape_checker = ShapeChecker()\n",
        " \n",
        "    shape_checker(x, 'batch t units')\n",
        "    shape_checker(context, 'batch s units')\n",
        "\n",
        "    attn_output, attn_scores = self.mha(\n",
        "        query=x,\n",
        "        value=context,\n",
        "        return_attention_scores=True)\n",
        "    \n",
        "    shape_checker(x, 'batch t units')\n",
        "    shape_checker(attn_scores, 'batch heads t s')\n",
        "    \n",
        "    # Cache the attention scores for plotting later.\n",
        "    attn_scores = tf.reduce_mean(attn_scores, axis=1)\n",
        "    shape_checker(attn_scores, 'batch t s')\n",
        "    self.last_attention_weights = attn_scores\n",
        "\n",
        "    x = self.add([x, attn_output])\n",
        "    x = self.layernorm(x)\n",
        "\n",
        "    return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7y7hjPkNMmHh"
      },
      "outputs": [],
      "source": [
        "attention_layer = CrossAttention(UNITS)\n",
        "\n",
        "# Attend to the encoded tokens\n",
        "embed = tf.keras.layers.Embedding(target_text_processor.vocabulary_size(),\n",
        "                                  output_dim=UNITS, mask_zero=True)\n",
        "ex_tar_embed = embed(ex_tar_in)\n",
        "\n",
        "result = attention_layer(ex_tar_embed, ex_context)\n",
        "\n",
        "print(f'Context sequence, shape (batch, s, units): {ex_context.shape}')\n",
        "print(f'Target sequence, shape (batch, t, units): {ex_tar_embed.shape}')\n",
        "print(f'Attention result, shape (batch, t, units): {result.shape}')\n",
        "print(f'Attention weights, shape (batch, t, s):    {attention_layer.last_attention_weights.shape}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vx9fUhi3Pmwp"
      },
      "source": [
        "The attention weights will sum to `1` over the context sequence, at each location in the target sequence."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zxyR7cmQPn9P"
      },
      "outputs": [],
      "source": [
        "attention_layer.last_attention_weights[0].numpy().sum(axis=-1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AagyXMH-Jhqt"
      },
      "source": [
        "\n",
        "\n",
        "Here are the attention weights across the context sequences at `t=0`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rqr8XGsAJlf6"
      },
      "outputs": [],
      "source": [
        "attention_weights = attention_layer.last_attention_weights\n",
        "mask=(ex_context_tok != 0).numpy()\n",
        "\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.pcolormesh(mask*attention_weights[:, 0, :])\n",
        "plt.title('Attention weights')\n",
        "\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.pcolormesh(mask)\n",
        "plt.title('Mask');\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Eil-C_NN1rp"
      },
      "source": [
        "Because the layer starts from a small random initialization, the attention weights are initially all close to `1/(sequence_length)`. The model will learn to make these less uniform as training progresses."
      ]
    },
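    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of why the initial weights are nearly uniform, the NumPy sketch below applies a masked softmax to small random scores. It is not how `MultiHeadAttention` is implemented; the helper `masked_softmax`, the toy `scores` and `mask`, and the `0.01` score scale are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def masked_softmax(scores, mask):\n",
        "  # Give padded positions a very negative score so they get ~0 weight.\n",
        "  scores = np.where(mask, scores, -1e9)\n",
        "  # Subtract the row maximum for numerical stability.\n",
        "  scores = scores - scores.max(axis=-1, keepdims=True)\n",
        "  weights = np.exp(scores)\n",
        "  return weights / weights.sum(axis=-1, keepdims=True)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# One target position attending over 5 context positions; the last 2 are padding.\n",
        "scores = 0.01 * rng.normal(size=(1, 5))\n",
        "mask = np.array([[True, True, True, False, False]])\n",
        "\n",
        "weights = masked_softmax(scores, mask)\n",
        "print(weights)               # Roughly equal weight on the 3 real tokens.\n",
        "print(weights.sum(axis=-1))  # Sums to 1 over the context axis.\n",
        "```\n",
        "\n",
        "With scores this small the softmax barely distinguishes the unmasked positions, so each one gets a weight near `1/3` in this toy case."
      ]
    },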
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aQ638eHN4iCK"
      },
      "source": [
        "### The decoder\n",
        "\n",
        "The decoder's job is to generate predictions for the next token at each location in the target sequence.\n",
        "\n",
        "1. It looks up embeddings for each token in the target sequence.\n",
        "2. It uses an RNN to process the target sequence, and keeps track of what it has generated so far.\n",
        "3. It uses the RNN output as the \"query\" to the attention layer, when attending to the encoder's output.\n",
        "4. At each location in the output it predicts the next token.\n",
        "\n",
        "When training, the model predicts the next word at each location. So it's important that the information only flows in one direction through the model. The decoder uses a unidirectional (not bidirectional) RNN to process the target sequence.\n",
        "\n",
        "When running inference with this model it produces one word at a time, and those are fed back into the model.\n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://tensorflow.org/images/tutorials/transformer/RNN.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth\u003eA unidirectional RNN\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pZsQJMqNmg_L"
      },
      "source": [
        "Here is the `Decoder` class' initializer. The initializer creates all the necessary layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "erYvHIgAl8kh"
      },
      "outputs": [],
      "source": [
        "class Decoder(tf.keras.layers.Layer):\n",
        "  @classmethod\n",
        "  def add_method(cls, fun):\n",
        "    setattr(cls, fun.__name__, fun)\n",
        "    return fun\n",
        "\n",
        "  def __init__(self, text_processor, units):\n",
        "    super(Decoder, self).__init__()\n",
        "    self.text_processor = text_processor\n",
        "    self.vocab_size = text_processor.vocabulary_size()\n",
        "    self.word_to_id = tf.keras.layers.StringLookup(\n",
        "        vocabulary=text_processor.get_vocabulary(),\n",
        "        mask_token='', oov_token='[UNK]')\n",
        "    self.id_to_word = tf.keras.layers.StringLookup(\n",
        "        vocabulary=text_processor.get_vocabulary(),\n",
        "        mask_token='', oov_token='[UNK]',\n",
        "        invert=True)\n",
        "    self.start_token = self.word_to_id('[START]')\n",
        "    self.end_token = self.word_to_id('[END]')\n",
        "\n",
        "    self.units = units\n",
        "\n",
        "\n",
        "    # 1. The embedding layer converts token IDs to vectors\n",
        "    self.embedding = tf.keras.layers.Embedding(self.vocab_size,\n",
        "                                               units, mask_zero=True)\n",
        "\n",
        "    # 2. The RNN keeps track of what's been generated so far.\n",
        "    self.rnn = tf.keras.layers.GRU(units,\n",
        "                                   return_sequences=True,\n",
        "                                   return_state=True,\n",
        "                                   recurrent_initializer='glorot_uniform')\n",
        "\n",
        "    # 3. The RNN output will be the query for the attention layer.\n",
        "    self.attention = CrossAttention(units)\n",
        "\n",
        "    # 4. This fully connected layer produces the logits for each\n",
        "    # output token.\n",
        "    self.output_layer = tf.keras.layers.Dense(self.vocab_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sd8-nRNzFR8x"
      },
      "source": [
        "#### Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UPnaw583CpnY"
      },
      "source": [
        "Next, the `call` method takes 4 arguments:\n",
        "\n",
        "* `context` - The context sequence from the encoder's output.\n",
        "* `x` - The target sequence input (token IDs).\n",
        "* `state` - Optional, the previous `state` output from the decoder (the internal state of the decoder's RNN). Pass the state from a previous run to continue generating text where you left off.\n",
        "* `return_state` - [Default: False] - Set this to `True` to return the RNN state."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PJOi5btHAPNK"
      },
      "outputs": [],
      "source": [
        "@Decoder.add_method\n",
        "def call(self,\n",
        "         context, x,\n",
        "         state=None,\n",
        "         return_state=False):  \n",
        "  shape_checker = ShapeChecker()\n",
        "  shape_checker(x, 'batch t')\n",
        "  shape_checker(context, 'batch s units')\n",
        "\n",
        "  # 1. Lookup the embeddings\n",
        "  x = self.embedding(x)\n",
        "  shape_checker(x, 'batch t units')\n",
        "\n",
        "  # 2. Process the target sequence.\n",
        "  x, state = self.rnn(x, initial_state=state)\n",
        "  shape_checker(x, 'batch t units')\n",
        "\n",
        "  # 3. Use the RNN output as the query for the attention over the context.\n",
        "  x = self.attention(x, context)\n",
        "  self.last_attention_weights = self.attention.last_attention_weights\n",
        "  shape_checker(x, 'batch t units')\n",
        "  shape_checker(self.last_attention_weights, 'batch t s')\n",
        "\n",
        "  # Step 4. Generate logit predictions for the next token.\n",
        "  logits = self.output_layer(x)\n",
        "  shape_checker(logits, 'batch t target_vocab_size')\n",
        "\n",
        "  if return_state:\n",
        "    return logits, state\n",
        "  else:\n",
        "    return logits"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E1-mLAcUEXpK"
      },
      "source": [
        "That will be sufficient for training. Create an instance of the decoder to test out:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4ZUMbYXIEVeA"
      },
      "outputs": [],
      "source": [
        "decoder = Decoder(target_text_processor, UNITS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SFWaI4wqzt4t"
      },
      "source": [
        "In training you'll use the decoder like this:\n",
        "\n",
        "Given the context and target tokens, for each target token it predicts the next target token. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5YM-lD7bzx18"
      },
      "outputs": [],
      "source": [
        "logits = decoder(ex_context, ex_tar_in)\n",
        "\n",
        "print(f'encoder output shape: (batch, s, units) {ex_context.shape}')\n",
        "print(f'input target tokens shape: (batch, t) {ex_tar_in.shape}')\n",
        "print(f'logits shape: (batch, t, target_vocabulary_size) {logits.shape}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zhS_tbk7VQkX"
      },
      "source": [
        "#### Inference\n",
        "\n",
        "To use it for inference you'll need a couple more methods:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SPm12cnIVRQr"
      },
      "outputs": [],
      "source": [
        "@Decoder.add_method\n",
        "def get_initial_state(self, context):\n",
        "  batch_size = tf.shape(context)[0]\n",
        "  start_tokens = tf.fill([batch_size, 1], self.start_token)\n",
        "  done = tf.zeros([batch_size, 1], dtype=tf.bool)\n",
        "  embedded = self.embedding(start_tokens)\n",
        "  return start_tokens, done, self.rnn.get_initial_state(embedded)[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TzeOhpBvVS5L"
      },
      "outputs": [],
      "source": [
        "@Decoder.add_method\n",
        "def tokens_to_text(self, tokens):\n",
        "  words = self.id_to_word(tokens)\n",
        "  result = tf.strings.reduce_join(words, axis=-1, separator=' ')\n",
        "  result = tf.strings.regex_replace(result, '^ *\\[START\\] *', '')\n",
        "  result = tf.strings.regex_replace(result, ' *\\[END\\] *$', '')\n",
        "  return result"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v6ildnz_V1MA"
      },
      "outputs": [],
      "source": [
        "@Decoder.add_method\n",
        "def get_next_token(self, context, next_token, done, state, temperature = 0.0):\n",
        "  logits, state = self(\n",
        "    context, next_token,\n",
        "    state = state,\n",
        "    return_state=True) \n",
        "  \n",
        "  if temperature == 0.0:\n",
        "    next_token = tf.argmax(logits, axis=-1)\n",
        "  else:\n",
        "    logits = logits[:, -1, :]/temperature\n",
        "    next_token = tf.random.categorical(logits, num_samples=1)\n",
        "\n",
        "  # If a sequence produces an `end_token`, set it `done`\n",
        "  done = done | (next_token == self.end_token)\n",
        "  # Once a sequence is done it only produces 0-padding.\n",
        "  next_token = tf.where(done, tf.constant(0, dtype=tf.int64), next_token)\n",
        "  \n",
        "  return next_token, done, state"
      ]
    },
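    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two branches in `get_next_token` can be sketched outside TensorFlow. The NumPy example below is only an illustration under stated assumptions: the helper `pick_token` and the toy `logits` are hypothetical, and it handles a single position rather than a batch.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def pick_token(logits, temperature, rng):\n",
        "  # temperature == 0.0: always take the most likely token (greedy decoding).\n",
        "  if temperature == 0.0:\n",
        "    return int(np.argmax(logits))\n",
        "  # Otherwise scale the logits and sample from the resulting softmax.\n",
        "  scaled = logits / temperature\n",
        "  probs = np.exp(scaled - scaled.max())\n",
        "  probs = probs / probs.sum()\n",
        "  return int(rng.choice(len(logits), p=probs))\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "logits = np.array([2.0, 1.0, 0.1])\n",
        "\n",
        "greedy = pick_token(logits, 0.0, rng)\n",
        "samples = [pick_token(logits, 1.0, rng) for _ in range(1000)]\n",
        "print(greedy, np.bincount(samples, minlength=3) / 1000)\n",
        "```\n",
        "\n",
        "Lower temperatures concentrate the samples on the highest-scoring token, while higher temperatures flatten the distribution."
      ]
    },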
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9WiXLrVs-FTE"
      },
      "source": [
        "With those extra functions, you can write a generation loop:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SuehagxL-JBZ"
      },
      "outputs": [],
      "source": [
        "# Set up the loop variables.\n",
        "next_token, done, state = decoder.get_initial_state(ex_context)\n",
        "tokens = []\n",
        "\n",
        "for n in range(10):\n",
        "  # Run one step.\n",
        "  next_token, done, state = decoder.get_next_token(\n",
        "      ex_context, next_token, done, state, temperature=1.0)\n",
        "  # Add the token to the output.\n",
        "  tokens.append(next_token)\n",
        "\n",
        "# Stack all the tokens together.\n",
        "tokens = tf.concat(tokens, axis=-1) # (batch, t)\n",
        "\n",
        "# Convert the tokens back to a string\n",
        "result = decoder.tokens_to_text(tokens)\n",
        "result[:3].numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5ALTdqCMLGSY"
      },
      "source": [
        "Since the model's untrained, it outputs items from the vocabulary almost uniformly at random."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B6xyru86m914"
      },
      "source": [
        "## The model\n",
        "\n",
        "Now that you have all the model components, combine them to build the model for training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WWIyuy71TkJT"
      },
      "outputs": [],
      "source": [
        "class Translator(tf.keras.Model):\n",
        "  @classmethod\n",
        "  def add_method(cls, fun):\n",
        "    setattr(cls, fun.__name__, fun)\n",
        "    return fun\n",
        "\n",
        "  def __init__(self, units,\n",
        "               context_text_processor,\n",
        "               target_text_processor):\n",
        "    super().__init__()\n",
        "    # Build the encoder and decoder\n",
        "    encoder = Encoder(context_text_processor, units)\n",
        "    decoder = Decoder(target_text_processor, units)\n",
        "\n",
        "    self.encoder = encoder\n",
        "    self.decoder = decoder\n",
        "\n",
        "  def call(self, inputs):\n",
        "    context, x = inputs\n",
        "    context = self.encoder(context)\n",
        "    logits = self.decoder(context, x)\n",
        "\n",
        "    #TODO(b/250038731): remove this\n",
        "    try:\n",
        "      # Delete the keras mask, so keras doesn't scale the loss+accuracy. \n",
        "      del logits._keras_mask\n",
        "    except AttributeError:\n",
        "      pass\n",
        "\n",
        "    return logits"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5rPi0FkS2iA5"
      },
      "source": [
        "During training the model will be used like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8vhjTh84K6Mg"
      },
      "outputs": [],
      "source": [
        "model = Translator(UNITS, context_text_processor, target_text_processor)\n",
        "\n",
        "logits = model((ex_context_tok, ex_tar_in))\n",
        "\n",
        "print(f'Context tokens, shape: (batch, s) {ex_context_tok.shape}')\n",
        "print(f'Target tokens, shape: (batch, t) {ex_tar_in.shape}')\n",
        "print(f'logits, shape: (batch, t, target_vocabulary_size) {logits.shape}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_ch_71VbIRfK"
      },
      "source": [
        "### Train"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8FmzjGmprVmE"
      },
      "source": [
        "For training, you'll want to implement your own masked loss and accuracy functions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WmTHr5iV3jFr"
      },
      "outputs": [],
      "source": [
        "def masked_loss(y_true, y_pred):\n",
        "    # Calculate the loss for each token in the batch.\n",
        "    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "        from_logits=True, reduction='none')\n",
        "    loss = loss_fn(y_true, y_pred)\n",
        "\n",
        "    # Mask off the losses on padding.\n",
        "    mask = tf.cast(y_true != 0, loss.dtype)\n",
        "    loss *= mask\n",
        "\n",
        "    # Return the mean loss over the non-padding tokens.\n",
        "    return tf.reduce_sum(loss)/tf.reduce_sum(mask)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nRB1CTmQWOIL"
      },
      "outputs": [],
      "source": [
        "def masked_acc(y_true, y_pred):\n",
        "    # Pick the most likely token at each position.\n",
        "    y_pred = tf.argmax(y_pred, axis=-1)\n",
        "    y_pred = tf.cast(y_pred, y_true.dtype)\n",
        "    \n",
        "    match = tf.cast(y_true == y_pred, tf.float32)\n",
        "    mask = tf.cast(y_true != 0, tf.float32)\n",
        "    # Only count matches at non-padding positions.\n",
        "    match *= mask\n",
        "    \n",
        "    return tf.reduce_sum(match)/tf.reduce_sum(mask)"
      ]
    },
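    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the masking does, here is a small NumPy sketch that computes a per-token cross-entropy and accuracy by hand while ignoring padding (token ID `0`). The arrays `y_true` and `probs` are made-up values, and the sketch works on probabilities rather than logits to keep it short.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# One sequence of 4 target tokens; the last two are padding (ID 0).\n",
        "y_true = np.array([[2, 1, 0, 0]])\n",
        "# Predicted probabilities over a 3-token vocabulary at each position.\n",
        "probs = np.array([[[0.1, 0.1, 0.8],\n",
        "                   [0.2, 0.5, 0.3],\n",
        "                   [0.6, 0.2, 0.2],\n",
        "                   [0.1, 0.8, 0.1]]])\n",
        "\n",
        "mask = (y_true != 0).astype(np.float32)\n",
        "\n",
        "# Cross-entropy of the true token at each position, zeroed on padding.\n",
        "true_probs = np.take_along_axis(probs, y_true[..., None], axis=-1)[..., 0]\n",
        "loss = -np.log(true_probs) * mask\n",
        "mean_loss = loss.sum() / mask.sum()\n",
        "\n",
        "# Accuracy counted only where the target is not padding.\n",
        "match = (probs.argmax(axis=-1) == y_true).astype(np.float32) * mask\n",
        "accuracy = match.sum() / mask.sum()\n",
        "\n",
        "print(mean_loss, accuracy)\n",
        "```\n",
        "\n",
        "Multiplying `match` by `mask` keeps a padded position where the prediction happens to be `0` (the third position here) from counting as a hit."
      ]
    },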
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f32GuAhw2nXm"
      },
      "source": [
        "Configure the model for training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9g0DRRvm3l9X"
      },
      "outputs": [],
      "source": [
        "model.compile(optimizer='adam',\n",
        "              loss=masked_loss, \n",
        "              metrics=[masked_acc, masked_loss])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5DWLI3pssjnx"
      },
      "source": [
        "The model is randomly initialized, and should give roughly uniform output probabilities. So it's easy to predict what the initial values of the metrics should be:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BuP3_LFENMJG"
      },
      "outputs": [],
      "source": [
        "vocab_size = 1.0 * target_text_processor.vocabulary_size()\n",
        "\n",
        "{\"expected_loss\": tf.math.log(vocab_size).numpy(),\n",
        " \"expected_acc\": 1/vocab_size}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "frVba49Usd0Z"
      },
      "source": [
        "That should roughly match the values returned by running a few steps of evaluation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8rJITfxEsHKR"
      },
      "outputs": [],
      "source": [
        "model.evaluate(val_ds, steps=20, return_dict=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BQd_esVVoSf3"
      },
      "outputs": [],
      "source": [
        "history = model.fit(\n",
        "    train_ds.repeat(), \n",
        "    epochs=100,\n",
        "    steps_per_epoch = 100,\n",
        "    validation_data=val_ds,\n",
        "    validation_steps = 20,\n",
        "    callbacks=[\n",
        "        tf.keras.callbacks.EarlyStopping(patience=3)])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "38rLdlmtQHCm"
      },
      "outputs": [],
      "source": [
        "plt.plot(history.history['loss'], label='loss')\n",
        "plt.plot(history.history['val_loss'], label='val_loss')\n",
        "plt.ylim([0, max(plt.ylim())])\n",
        "plt.xlabel('Epoch #')\n",
        "plt.ylabel('CE/token')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KkhXRASNG80_"
      },
      "outputs": [],
      "source": [
        "plt.plot(history.history['masked_acc'], label='accuracy')\n",
        "plt.plot(history.history['val_masked_acc'], label='val_accuracy')\n",
        "plt.ylim([0, max(plt.ylim())])\n",
        "plt.xlabel('Epoch #')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mU3Ce8M6I3rz"
      },
      "source": [
        "### Translate\n",
        "\n",
        "Now that the model is trained, implement a function to execute the full `text =\u003e text` translation. This code is basically identical to the [inference example](#inference) in the [decoder section](#the_decoder), but this also captures the attention weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mmgYPCVgEwp_"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "@Translator.add_method\n",
        "def translate(self,\n",
        "              texts, *,\n",
        "              max_length=50,\n",
        "              temperature=0.0):\n",
        "  # Process the input texts\n",
        "  context = self.encoder.convert_input(texts)\n",
        "  batch_size = tf.shape(texts)[0]\n",
        "\n",
        "  # Setup the loop inputs\n",
        "  tokens = []\n",
        "  attention_weights = []\n",
        "  next_token, done, state = self.decoder.get_initial_state(context)\n",
        "\n",
        "  for _ in range(max_length):\n",
        "    # Generate the next token\n",
        "    next_token, done, state = self.decoder.get_next_token(\n",
        "        context, next_token, done,  state, temperature)\n",
        "        \n",
        "    # Collect the generated tokens\n",
        "    tokens.append(next_token)\n",
        "    attention_weights.append(self.decoder.last_attention_weights)\n",
        "    \n",
        "    if tf.executing_eagerly() and tf.reduce_all(done):\n",
        "      break\n",
        "\n",
        "  # Stack the lists of tokens and attention weights.\n",
        "  tokens = tf.concat(tokens, axis=-1)   # t*[(batch 1)] -\u003e (batch, t)\n",
        "  self.last_attention_weights = tf.concat(attention_weights, axis=1)  # t*[(batch 1 s)] -\u003e (batch, t s)\n",
        "\n",
        "  result = self.decoder.tokens_to_text(tokens)\n",
        "  return result"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U4XufRntbbva"
      },
      "source": [
        "Try the `translate` method on a short sentence:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E5hqvbR5FUCD"
      },
      "outputs": [],
      "source": [
        "result = model.translate(['¿Todavía está en casa?']) # Are you still home\n",
        "result[0].numpy().decode()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wQ1iU63cVgfs"
      },
      "source": [
        "Use that to generate the attention plot:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s5hQWlbN3jGF"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "@Translator.add_method\n",
        "def plot_attention(self, text, **kwargs):\n",
        "  assert isinstance(text, str)\n",
        "  output = self.translate([text], **kwargs)\n",
        "  output = output[0].numpy().decode()\n",
        "\n",
        "  attention = self.last_attention_weights[0]\n",
        "\n",
        "  context = tf_lower_and_split_punct(text)\n",
        "  context = context.numpy().decode().split()\n",
        "\n",
        "  output = tf_lower_and_split_punct(output)\n",
        "  output = output.numpy().decode().split()[1:]\n",
        "\n",
        "  fig = plt.figure(figsize=(10, 10))\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "\n",
        "  ax.matshow(attention, cmap='viridis', vmin=0.0)\n",
        "\n",
        "  fontdict = {'fontsize': 14}\n",
        "\n",
        "  ax.set_xticklabels([''] + context, fontdict=fontdict, rotation=90)\n",
        "  ax.set_yticklabels([''] + output, fontdict=fontdict)\n",
        "\n",
        "  ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
        "  ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
        "\n",
        "  ax.set_xlabel('Input text')\n",
        "  ax.set_ylabel('Output text')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rrGawQv2eiA4"
      },
      "outputs": [],
      "source": [
        "model.plot_attention('¿Todavía está en casa?') # Are you still home"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JHBdOf9duumm"
      },
      "source": [
        "Translate a few more sentences and plot them:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "flT0VlQZK11s"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "# This is my life.\n",
        "model.plot_attention('Esta es mi vida.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t-fPYP_9K8xa"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "# Try to find out.\n",
        "model.plot_attention('Tratar de descubrir.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rA3xI3NzrRJt"
      },
      "source": [
        "The short sentences often work well, but if the input is too long the model literally loses focus and stops providing reasonable predictions. There are two main reasons for this:\n",
        "\n",
        "1. The model was trained with teacher-forcing, which feeds the correct token at each step regardless of the model's predictions. The model could be made more robust if it were sometimes fed its own predictions.\n",
        "2. The model only has access to its previous output through the RNN state. If the RNN state loses track of where it was in the context sequence there's no way for the model to recover. [Transformers](transformer.ipynb) improve on this by letting the decoder look at what it has output so far."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vtz6QBoGWqT2"
      },
      "source": [
        "The raw data is sorted by length, so try translating the longest sequence:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-FUHFLEvSMbG"
      },
      "outputs": [],
      "source": [
        "long_text = context_raw[-1]\n",
        "\n",
        "import textwrap\n",
        "print('Expected output:\\n', '\\n'.join(textwrap.wrap(target_raw[-1])))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lDa_8NaN_RUy"
      },
      "outputs": [],
      "source": [
        "model.plot_attention(long_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PToqG3GiIUPM"
      },
      "source": [
        "The `translate` function works on batches, so if you have multiple texts to translate you can pass them all at once, which is much more efficient than translating them one at a time:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1-FLCjBVEMXL"
      },
      "outputs": [],
      "source": [
        "inputs = [\n",
        "    'Hace mucho frio aqui.', # \"It's really cold here.\"\n",
        "    'Esta es mi vida.', # \"This is my life.\"\n",
        "    'Su cuarto es un desastre.' # \"His room is a mess\"\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sT68i4jYEQ7q"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "for t in inputs:\n",
        "  print(model.translate([t])[0].numpy().decode())\n",
        "\n",
        "print()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hd2rgyHwVVrv"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "result = model.translate(inputs)\n",
        "\n",
        "print(result[0].numpy().decode())\n",
        "print(result[1].numpy().decode())\n",
        "print(result[2].numpy().decode())\n",
        "print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uvhMqIw26Bwd"
      },
      "source": [
        "So overall this text generation function mostly gets the job done, but so far you've only used it here in Python with eager execution. Let's try to export it next:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X4POAuUgLxLv"
      },
      "source": [
        "### Export"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S-6cFyqeUPQm"
      },
      "source": [
        "If you want to export this model you'll need to wrap the `translate` method in a `tf.function`. That implementation will get the job done:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fNhGwQaVKIAy"
      },
      "outputs": [],
      "source": [
        "class Export(tf.Module):\n",
        "  def __init__(self, model):\n",
        "    self.model = model\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.string, shape=[None])])\n",
        "  def translate(self, inputs):\n",
        "    return self.model.translate(inputs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5Tjqs9FzNwW5"
      },
      "outputs": [],
      "source": [
        "export = Export(model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fkccvHDvXCa8"
      },
      "source": [
        "Run the `tf.function` once to compile it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_NzrixLvVBjQ"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "_ = export.translate(tf.constant(inputs))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "USJdu00tVFbd"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "result = export.translate(tf.constant(inputs))\n",
        "\n",
        "print(result[0].numpy().decode())\n",
        "print(result[1].numpy().decode())\n",
        "print(result[2].numpy().decode())\n",
        "print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NP2dNtEXJPEL"
      },
      "source": [
        "Now that the function has been traced it can be exported using `saved_model.save`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OyvxT5V0_X5B"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "tf.saved_model.save(export, 'translator',\n",
        "                    signatures={'serving_default': export.translate})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-I0j3i3ekOba"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "reloaded = tf.saved_model.load('translator')\n",
        "_ = reloaded.translate(tf.constant(inputs)) #warmup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GXZF__FZXJCm"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "result = reloaded.translate(tf.constant(inputs))\n",
        "\n",
        "print(result[0].numpy().decode())\n",
        "print(result[1].numpy().decode())\n",
        "print(result[2].numpy().decode())\n",
        "print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pgg3P757O5rw"
      },
      "source": [
        "#### [Optional] Use a dynamic loop"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3230LfyRIJQV"
      },
      "source": [
        "It's worth noting that this initial implementation is not optimal. It uses a python loop:\n",
        "\n",
        "```\n",
        "for _ in range(max_length):\n",
        "  ...\n",
        "  if tf.executing_eagerly() and tf.reduce_all(done):\n",
        "    break\n",
        "```\n",
        "\n",
        "The Python loop is relatively simple, but when `tf.function` converts it to a graph, it **statically unrolls** the loop. Unrolling the loop has three disadvantages:\n",
        "\n",
        "1. It makes `max_length` copies of the loop body, so the generated graphs take longer to build, save, and load.\n",
        "1. You have to choose a fixed value for `max_length`.\n",
        "1. You can't `break` from a statically unrolled loop. The `tf.function`\n",
        "  version will run the full `max_length` iterations on every call.\n",
        "  That's why the `break` only works with eager execution. This is\n",
        "  still marginally faster than eager execution, but not as fast as it could be.\n"
      ]
    },
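    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the cost of unrolling concrete, here is a minimal sketch that is separate from the translator. It traces a toy function with two loop lengths and counts the operations in each graph. The names `count_ops` and `unrolled` are made up for illustration, and the exact counts may vary between TensorFlow versions:\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "def count_ops(n_steps):\n",
        "  # Trace a toy function whose Python loop runs `n_steps` times.\n",
        "  @tf.function\n",
        "  def unrolled(x):\n",
        "    for _ in range(n_steps):  # Python loop: unrolled at trace time.\n",
        "      x = x * 2.0 + 1.0\n",
        "    return x\n",
        "\n",
        "  graph = unrolled.get_concrete_function(\n",
        "      tf.TensorSpec(shape=[], dtype=tf.float32)).graph\n",
        "  return len(graph.get_operations())\n",
        "\n",
        "small = count_ops(5)\n",
        "large = count_ops(50)\n",
        "print(small, large)\n",
        "```\n",
        "\n",
        "The graph for the longer loop should contain noticeably more operations, since each iteration adds its own copy of the loop body."
      ]
    },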
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zPRJp4TRJx_n"
      },
      "source": [
        "To fix these shortcomings, the version of the `translate` method defined below uses a TensorFlow loop:\n",
        "\n",
        "```\n",
        "for t in tf.range(max_length):\n",
        "  ...\n",
        "  if tf.reduce_all(done):\n",
        "      break\n",
        "```\n",
        "\n",
        "It looks like a Python loop, but when you use a tensor as the input to a `for` loop (or as the condition of a `while` loop), `tf.function` converts it to a dynamic loop using operations like `tf.while_loop`.\n",
        "\n",
        "Strictly speaking, `max_length` is no longer required here. It's kept only as a safeguard in case the model gets stuck generating a repeating sequence like: `the united states of the united states of the united states...`.\n",
        "\n",
        "On the downside, to accumulate tokens from this dynamic loop you can't just append them to a Python `list`. You need to use a `tf.TensorArray`:\n",
        "\n",
        "```\n",
        "tokens = tf.TensorArray(tf.int64, size=1, dynamic_size=True)\n",
        "...\n",
        "for t in tf.range(max_length):\n",
        "  ...\n",
        "  tokens = tokens.write(t, next_token) # next_token shape is (batch, 1)\n",
        "...\n",
        "tokens = tokens.stack()\n",
        "tokens = einops.rearrange(tokens, 't batch 1 -\u003e batch t')\n",
        "```"
      ]
    },
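    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small self-contained illustration of the same pattern, separate from the translator, the sketch below uses a toy function (`count_until` is a hypothetical name) that writes the loop index into a `tf.TensorArray` and breaks once it reaches `limit`. The final check assumes the dynamic loop appears in the traced graph as an op whose type contains `While`:\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "@tf.function\n",
        "def count_until(limit, max_steps):\n",
        "  # Accumulate results in a TensorArray, since a Python list can't grow\n",
        "  # inside a dynamic loop.\n",
        "  values = tf.TensorArray(tf.int32, size=0, dynamic_size=True)\n",
        "  for t in tf.range(max_steps):  # Tensor input, so this is a dynamic loop.\n",
        "    values = values.write(t, t)\n",
        "    if tf.equal(t, limit):  # Tensor condition, so the break happens in-graph.\n",
        "      break\n",
        "  return values.stack()\n",
        "\n",
        "out = count_until(tf.constant(3), tf.constant(100))\n",
        "print(out.numpy())  # Stops well before max_steps.\n",
        "\n",
        "# Look for a while-loop op in the traced graph.\n",
        "graph = count_until.get_concrete_function(\n",
        "    tf.constant(3), tf.constant(100)).graph\n",
        "has_while_op = any('While' in op.type for op in graph.get_operations())\n",
        "print(has_while_op)\n",
        "```\n",
        "\n",
        "Because the loop exits as soon as the condition is met, the upper bound only matters when the condition is never reached."
      ]
    },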
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rTmISp4SRo5U"
      },
      "source": [
        "This version of the code can be quite a bit more efficient:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EbQpyYs13jF_"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "@Translator.add_method\n",
        "def translate(self,\n",
        "              texts,\n",
        "              *,\n",
        "              max_length=500,\n",
        "              temperature=tf.constant(0.0)):\n",
        "  shape_checker = ShapeChecker()\n",
        "  context = self.encoder.convert_input(texts)\n",
        "  batch_size = tf.shape(context)[0]\n",
        "  shape_checker(context, 'batch s units')\n",
        "\n",
        "  next_token, done, state = self.decoder.get_initial_state(context)\n",
        "\n",
        "  # initialize the accumulator\n",
        "  tokens = tf.TensorArray(tf.int64, size=1, dynamic_size=True)\n",
        "\n",
        "  for t in tf.range(max_length):\n",
        "    # Generate the next token\n",
        "    next_token, done, state = self.decoder.get_next_token(\n",
        "        context, next_token, done, state, temperature)\n",
        "    shape_checker(next_token, 'batch t1')\n",
        "\n",
        "    # Collect the generated tokens\n",
        "    tokens = tokens.write(t, next_token)\n",
        "\n",
        "    # if all the sequences are done, break\n",
        "    if tf.reduce_all(done):\n",
        "      break\n",
        "\n",
        "  # Stack the generated token ids into a tensor and convert them to strings.\n",
        "  tokens = tokens.stack()\n",
        "  shape_checker(tokens, 't batch t1')\n",
        "  tokens = einops.rearrange(tokens, 't batch 1 -\u003e batch t')\n",
        "  shape_checker(tokens, 'batch t')\n",
        "\n",
        "  text = self.decoder.tokens_to_text(tokens)\n",
        "  shape_checker(text, 'batch')\n",
        "\n",
        "  return text"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AJ_NznOgZTxC"
      },
      "source": [
        "With eager execution this implementation performs on par with the original:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JRh66y-YYeBw"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "result = model.translate(inputs)\n",
        "\n",
        "print(result[0].numpy().decode())\n",
        "print(result[1].numpy().decode())\n",
        "print(result[2].numpy().decode())\n",
        "print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l6B8W4_MZdX0"
      },
      "source": [
        "But when you wrap it in a `tf.function`, you'll notice two differences."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EQlrhWWrUhgT"
      },
      "outputs": [],
      "source": [
        "class Export(tf.Module):\n",
        "  def __init__(self, model):\n",
        "    self.model = model\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.string, shape=[None])])\n",
        "  def translate(self, inputs):\n",
        "    return self.model.translate(inputs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pH8yyGHvUmti"
      },
      "outputs": [],
      "source": [
        "export = Export(model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZnOJvIsvUwBL"
      },
      "source": [
        "First, it's much quicker to trace, since it only creates one copy of the loop body:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_CaEbHkwEa1S"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "_ = export.translate(inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2ABEwtKIZ6eE"
      },
      "source": [
        "Second, the `tf.function` is much faster than running with eager execution, and on small inputs it's often several times faster than the unrolled version, because it can break out of the loop."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d5VdCLxPYrpz"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "result = export.translate(inputs)\n",
        "\n",
        "print(result[0].numpy().decode())\n",
        "print(result[1].numpy().decode())\n",
        "print(result[2].numpy().decode())\n",
        "print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3DDmofICJdx0"
      },
      "source": [
        "So save this version as well:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eCg7kRq6FVl3"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "tf.saved_model.save(export, 'dynamic_translator',\n",
        "                    signatures={'serving_default': export.translate})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zrpzxL2vFVl3"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "reloaded = tf.saved_model.load('dynamic_translator')\n",
        "_ = reloaded.translate(tf.constant(inputs)) #warmup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5TjSwrCEFVl3"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "result = reloaded.translate(tf.constant(inputs))\n",
        "\n",
        "print(result[0].numpy().decode())\n",
        "print(result[1].numpy().decode())\n",
        "print(result[2].numpy().decode())\n",
        "print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RTe5P5ioMJwN"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n",
        "* Experiment with training on a larger dataset, or using more epochs.\n",
        "* Try the [transformer tutorial](transformer.ipynb) which implements a similar translation task but uses transformer layers instead of RNNs. This version also uses a `text.BertTokenizer` to implement word-piece tokenization.\n",
        "* Visit the [`tensorflow_addons.seq2seq` tutorial](https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt), which demonstrates higher-level functionality for implementing this sort of sequence-to-sequence model, such as `seq2seq.BeamSearchDecoder`."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "nmt_with_attention.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
